package com.lw.leetcode.tree.b;

import java.util.*;

/**
 * Created with IntelliJ IDEA.
 * 1443. Minimum Time to Collect All Apples in a Tree (收集树上所有苹果的最少时间)
 *
 * @author liw
 * @version 1.0
 * @date 2022/4/7 16:44
 */
public class MinTime {


    /**
     * Runs the sample cases from the problem statement and prints each answer.
     * The expected output is noted in the comment above each call.
     */
    public static void main(String[] args) {
        MinTime test = new MinTime();
        int[][] tree = {{0, 1}, {0, 2}, {1, 4}, {1, 5}, {2, 3}, {2, 6}};

        // 8
        System.out.println(test.minTime(7, tree,
                Arrays.asList(false, false, true, false, true, true, false)));

        // 6
        System.out.println(test.minTime(7, tree,
                Arrays.asList(false, false, true, false, false, true, false)));

        // 0
        System.out.println(test.minTime(7, tree,
                Arrays.asList(false, false, false, false, false, false, false)));

        // 4
        System.out.println(test.minTime(3, new int[][]{{0, 2}, {1, 2}},
                Arrays.asList(false, true, false)));
    }


    /**
     * Computes the minimum time, in seconds, to start at node 0, collect every apple
     * in the tree and return to node 0, where walking one edge takes one second.
     *
     * <p>An edge between a node and its parent has to be walked twice (down and back up)
     * exactly when the node's subtree contains at least one apple. The answer is
     * therefore {@code 2 * (number of non-root nodes whose subtree holds an apple)}.
     * Nodes unreachable from node 0 are ignored. The method keeps no state between calls.
     * It traverses iteratively, so deep (path-like) trees cannot overflow the stack.
     *
     * @param n        number of nodes; used only as a lower bound for sizing, the effective
     *                 node count also considers {@code hasApple.size()} and the edge endpoints
     * @param edges    undirected tree edges, each {@code {a, b}}
     * @param hasApple {@code hasApple.get(i)} tells whether node {@code i} holds an apple
     * @return the minimum collection time; {@code 0} when no apple needs collecting
     */
    public int minTime(int n, int[][] edges, List<Boolean> hasApple) {
        int size = nodeCount(n, edges, hasApple);
        List<List<Integer>> adjacency = buildAdjacency(size, edges);

        // Breadth-first traversal from the root records each node's parent and a visiting
        // order in which every parent appears before its children.
        int[] parent = new int[size];
        int[] order = new int[size];
        boolean[] visited = new boolean[size];
        int head = 0;
        int tail = 0;
        order[tail++] = 0;
        visited[0] = true;
        while (head < tail) {
            int node = order[head++];
            for (int next : adjacency.get(node)) {
                if (!visited[next]) {
                    visited[next] = true;
                    parent[next] = node;
                    order[tail++] = next;
                }
            }
        }

        // Sweep the order backwards so every child is settled before its parent. The root
        // (index 0) is skipped because it has no edge above it, and its own apple costs
        // nothing to collect.
        boolean[] subtreeHasApple = new boolean[size];
        int time = 0;
        for (int i = tail - 1; i > 0; i--) {
            int node = order[i];
            if (hasApple.get(node) || subtreeHasApple[node]) {
                time += 2;
                subtreeHasApple[parent[node]] = true;
            }
        }
        return time;
    }

    /**
     * Returns how many node slots are needed: at least one (the root), {@code n},
     * {@code hasApple.size()}, and one past the largest edge endpoint.
     */
    private static int nodeCount(int n, int[][] edges, List<Boolean> hasApple) {
        int size = Math.max(1, Math.max(n, hasApple.size()));
        for (int[] edge : edges) {
            size = Math.max(size, Math.max(edge[0], edge[1]) + 1);
        }
        return size;
    }

    /** Builds an undirected adjacency list with {@code size} (possibly empty) entries. */
    private static List<List<Integer>> buildAdjacency(int size, int[][] edges) {
        List<List<Integer>> adjacency = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            adjacency.add(new ArrayList<>());
        }
        for (int[] edge : edges) {
            adjacency.get(edge[0]).add(edge[1]);
            adjacency.get(edge[1]).add(edge[0]);
        }
        return adjacency;
    }

}
